import json
import logging
import os
import re
import time
import base64
import numpy as np
import argparse

from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from openai import OpenAI

from gym_wrapper import CoopCommandGymEnv

# Configure root logging at import time: records at INFO and above go both to
# qwen_eval.log and to the console with a shared timestamped format.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('qwen_eval.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)  # module-level logger used by this file

# API configurations, keyed by provider name.
# Every provider is accessed through the OpenAI-compatible client, so any
# provider that is not served from the default OpenAI endpoint MUST set
# "base_url".
#   api_key_env:    environment variable holding the provider's API key
#   base_url:       OpenAI-compatible endpoint; None means the OpenAI default
#   default_model:  model used when no explicit model name is given
#   vision_support: advertised image-input capability (logged / recorded)
#   audio_support:  False -> audio content is conveyed via text description
API_CONFIGS = {
    "qwen": {
        "api_key_env": "DASHSCOPE_API_KEY",
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "default_model": "qwen2.5-omni-7b",
        "vision_support": True,
        "audio_support": False  # Through text description
    },
    "openai": {
        "api_key_env": "OPENAI_API_KEY",
        "base_url": None,  # Use default OpenAI endpoint
        "default_model": "gpt-4o",
        "vision_support": True,
        "audio_support": False  # Through text description
    },
    "gemini": {
        "api_key_env": "GEMINI_API_KEY",
        # Gemini's OpenAI-compatibility endpoint. With base_url=None the client
        # would send the Gemini key to api.openai.com and fail authentication.
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
        "default_model": "gemini-2.0-flash",
        "vision_support": True,
        "audio_support": False
    }
}

class MultiProviderEvaluator:
    """Multi-provider evaluator for cooperative command game supporting Qwen and OpenAI."""
    
    def __init__(self, difficulty: str = "normal", seed_index: int = 0, 
                 max_rounds: Optional[int] = None, enable_stream: bool = True,
                 save_media: bool = True, deterministic_commands: bool = True,
                 api_provider: str = "qwen", model_name: Optional[str] = None,
                 input_mode: str = "image_audio", include_vector_text: bool = True,
                 enhanced_video: bool = False, video_fps: float = 0.5, 
                 audio_duration_per_frame: float = 3.0, num_episodes: int = 10):
        """
        Initialize the multi-provider evaluator.
        
        Args:
            difficulty: Game difficulty ("normal", "medium", "hard")
            seed_index: Random seed index for reproducible results
            max_rounds: Maximum number of game rounds (auto-calculated if None)
            enable_stream: Enable streaming responses from the model
            save_media: Save images, audio, and responses to files
            deterministic_commands: Commands execute deterministically (vs probabilistic)
            api_provider: API provider ("qwen", "openai", "gemini")
            model_name: Model name (uses provider default if None)
            input_mode: Input mode ("image_audio" or "video")
            include_vector_text: Include vector information as text in prompt (default: True).
                                If False, model must interpret vector info from visual input only.
            enhanced_video: Enable enhanced video creation with integrated audio (default: False)
            video_fps: Frames per second for video recording (default: 0.5 for audio integration)
            audio_duration_per_frame: Expected audio duration per frame in seconds (default: 3.0)
            num_episodes: Number of episodes to evaluate (default: 10)
        """
        self.difficulty = difficulty
        self.seed_index = seed_index
        self.max_rounds = max_rounds
        self.enable_stream = enable_stream
        self.save_media = save_media
        self.deterministic_commands = deterministic_commands
        self.api_provider = api_provider.lower()
        self.input_mode = input_mode.lower()
        self.include_vector_text = include_vector_text
        self.enhanced_video = enhanced_video
        self.video_fps = video_fps
        self.audio_duration_per_frame = audio_duration_per_frame
        self.num_episodes = num_episodes  # 新增：episode数量
        
        # Validate input mode
        if self.input_mode not in ["image_audio", "video"]:
            raise ValueError(f"Unsupported input mode: {input_mode}. Choose from: ['image_audio', 'video']")
        
        # Validate API provider
        if self.api_provider not in API_CONFIGS:
            raise ValueError(f"Unsupported API provider: {api_provider}. Choose from: {list(API_CONFIGS.keys())}")
        
        # Set up API configuration
        self.api_config = API_CONFIGS[self.api_provider]
        self.model_name = model_name or self.api_config["default_model"]
        
        # Initialize API client
        self._init_api_client()
        
        # Create output directory for this evaluation
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        output_root = "outputs"
        os.makedirs(output_root, exist_ok=True)
        self.output_dir = Path(f"{output_root}/{self.api_provider}_eval_{difficulty}_seed{seed_index}_{num_episodes}ep_{timestamp}")
        if self.save_media:
            self.output_dir.mkdir(exist_ok=True)
            self.images_dir = self.output_dir / "images"
            self.audio_dir = self.output_dir / "audio"
            self.videos_dir = self.output_dir / "videos"
            self.responses_dir = self.output_dir / "responses"
            self.images_dir.mkdir(exist_ok=True)
            self.audio_dir.mkdir(exist_ok=True)
            self.videos_dir.mkdir(exist_ok=True)
            self.responses_dir.mkdir(exist_ok=True)
        
        # 新增：为每个episode创建子目录
        self.episode_dirs = {}
        if self.save_media:
            for ep in range(num_episodes):
                ep_dir = self.output_dir / f"episode_{ep:02d}"
                ep_dir.mkdir(exist_ok=True)
                self.episode_dirs[ep] = {
                    "root": ep_dir,
                    "images": ep_dir / "images",
                    "audio": ep_dir / "audio", 
                    "videos": ep_dir / "videos",
                    "responses": ep_dir / "responses"
                }
                # 创建子目录
                for subdir in ["images", "audio", "videos", "responses"]:
                    self.episode_dirs[ep][subdir].mkdir(exist_ok=True)
        
        # 不在初始化时创建环境，将在每个episode开始时创建
        self.env = None
        
        # Results tracking - 修改为支持多episode
        self.results = {
            "config": {
                "difficulty": difficulty,
                "seed_index": seed_index,
                "max_rounds": max_rounds,
                "api_provider": self.api_provider,
                "model": self.model_name,
                "input_mode": self.input_mode,
                "include_vector_text": self.include_vector_text,
                "enhanced_video": self.enhanced_video,
                "video_fps": self.video_fps,
                "audio_duration_per_frame": self.audio_duration_per_frame,
                "vision_support": self.api_config["vision_support"],
                "num_episodes": num_episodes,  # 新增
                "timestamp": datetime.now().isoformat(),
                "output_directory": str(self.output_dir) if self.save_media else None
            },
            "episodes": [],  # 修改：存储每个episode的结果
            "summary_stats": {},  # 新增：汇总统计
            "media_files": {
                "images": [],
                "audio": [],
                "videos": [],
                "responses": []
            },
            "command_compliance": {
                "total_turns": 0,
                "valid_single_commands": 0,
                "multiple_command_violations": 0,
                "no_command_found": 0,
                "compliance_rate": 0.0
            }
        }
        
        # Command types for reference
        self.command_types = ["move", "attack", "defend", "recon", "status"]
        
        # Store member list for multi-member commands
        self._last_member_list = []
        
        # Enhanced system prompt with movement error and hidden objectives
        command_reliability_note = """
⚙️ COMMAND EXECUTION:
- Commands execute deterministically - all valid commands will succeed
- Focus on strategic positioning and optimal task assignment
- No need to account for random command failures

""" if deterministic_commands else """
⚙️ COMMAND EXECUTION:
- Commands may randomly fail based on team member capability and status
- Scout: 85% success rate, Heavy: 75%, Medic: 90%, Engineer: 80%, Sniper: 70%
- Injured or low-health members have reduced success rates
- Plan for potential command failures and have backup strategies

"""
       
        # Build media description based on input mode and enhanced video settings
        if self.input_mode == "video":
            if self.enhanced_video:
                media_description = f"""- Video: Visual sequence with integrated audio (FPS: {actual_fps}, Audio per frame: {self.audio_duration_per_frame}s)
  Shows game state progression and team member movements with synchronized voice communications"""
            else:
                media_description = "- Video: Visual sequence showing game state progression and team member movements over time"
        else:
            media_description = """- Audio: Voice reports from team members about discoveries, status changes
- Image: Visual representation of current game state"""

        # Build information sources section based on include_vector_text setting
        if self.include_vector_text:
            info_sources = f"""📊 INFORMATION PROVIDED:
- Vector: Team member states (health, status, position) + global info (rounds remaining, normalized score)
{media_description}
- Discovery hints: Clues about nearby hidden objectives"""
        else:
            info_sources = f"""📊 INFORMATION PROVIDED:
{media_description}
- Discovery hints: Clues about nearby hidden objectives

⚠️ NOTE: Vector information is available visually - interpret team states from the visual input."""

        self.system_prompt = f"""
You are commanding a military team in a cooperative mission. You MUST provide EXACTLY ONE command per turn.

❌ FORBIDDEN: Multiple commands like "COMMAND: 0 move 20 30" AND "COMMAND: 1 recon 40 50"
✅ CORRECT: Only one command like "COMMAND: 0 move 20 30"

If you provide more than one command, the system will ERROR and use a default command instead.

KEY GAME MECHANICS:

{command_reliability_note}

🎯 HIDDEN OBJECTIVES:
- Some objectives are HIDDEN and not visible initially
- You must EXPLORE different areas to discover hidden objectives
- Scout team members have higher discovery probability (80% vs 40%)
- Send scouts to unexplored areas to find new objectives
- Discovery hints may indicate "unusual activity" in areas with hidden objectives

⚠️ MOVEMENT UNCERTAINTY:
- Team members DO NOT move to exact coordinates you specify
- Movement has ERROR based on:
  * Role precision (Scout: low error, Heavy: high error)
  * Health status (injured = more error)
  * Movement distance (longer moves = more error)
- Expect actual positions to deviate from your targets
- Plan for imprecise movement in your strategy

{info_sources}

🎮 STRATEGIC CONSIDERATIONS:
- Balance exploration (finding hidden objectives) vs completion (finishing known objectives)
- Use scouts for exploration and discovery
- Account for movement errors in positioning
- Monitor team health and status for optimal assignment
- Hidden objectives may have high score values - worth discovering!

🚨 COMMAND FORMAT - PROVIDE EXACTLY ONE OF THESE:

**Individual Command (one member):**
COMMAND: [member_id] [action] [x] [y]

**Team Command (all members together):**
COMMAND: all [action] [x] [y]

**Multi-member Command (specific members together):**
COMMAND: 0,1,2 [action] [x] [y]

**Available Actions:** move, attack, defend, recon, status
**Coordinates:** x, y: 0-100 (actual position will vary due to movement error)

EXAMPLES OF CORRECT RESPONSES:
✅ "Based on the current situation, I'll send the scout to explore. COMMAND: 0 recon 25 30"
✅ "The team should move together to the objective. COMMAND: all move 45 20"
✅ "Two scouts should explore this area. COMMAND: 0,1 recon 70 80"

EXAMPLES OF INCORRECT RESPONSES (WILL CAUSE ERRORS):
❌ "COMMAND: 0 move 25 30" followed by "COMMAND: 1 recon 45 20"
❌ Multiple command lines in any form
❌ Suggesting multiple commands for "efficient coordination"

🚨 FINAL REMINDER: ONE COMMAND ONLY! 🚨
- Analyze the situation thoroughly
- Choose the SINGLE most important action
- Provide exactly ONE command
- Plan step-by-step across multiple turns, not all at once

Provide your strategic analysis, then end with exactly ONE command.
"""

    def _init_api_client(self):
        """Initialize the API client based on the selected provider."""
        api_key = os.getenv(self.api_config["api_key_env"])
        if not api_key:
            raise ValueError(f"Missing API key. Please set {self.api_config['api_key_env']} environment variable.")
        
        # Create client with provider-specific configuration
        if self.api_config["base_url"]:
            self.client = OpenAI(
                api_key=api_key,
                base_url=self.api_config["base_url"]
            )
        else:
            self.client = OpenAI(api_key=api_key)
        
        logger.info(f"Initialized {self.api_provider.upper()} client with model: {self.model_name}")
        logger.info(f"Vision support: {self.api_config['vision_support']}")

    def _save_image(self, image_data: str, step: int) -> Optional[str]:
        """Save base64 image data to file and return relative path."""
        if not self.save_media or image_data is None or (hasattr(image_data, 'size') and image_data.size == 0):
            return None
        
        try:
            # Handle both base64 string and numpy array input
            if isinstance(image_data, str):
                # Decode base64 image
                image_bytes = base64.b64decode(image_data)
            elif hasattr(image_data, 'shape'):
                # Convert numpy array to base64
                from PIL import Image
                import io
                
                if len(image_data.shape) == 3:
                    # Convert numpy array to PIL Image
                    image_pil = Image.fromarray(image_data.astype(np.uint8))
                    
                    # Convert to base64
                    buffer = io.BytesIO()
                    image_pil.save(buffer, format='JPEG', quality=85)
                    image_bytes = buffer.getvalue()
                else:
                    logger.warning(f"Unexpected image shape: {image_data.shape}")
                    return None
            else:
                logger.warning(f"Unexpected image type: {type(image_data)}")
                return None
            
            # Save to file
            filename = f"step_{step:03d}_image.jpg"
            filepath = self.images_dir / filename
            
            with open(filepath, 'wb') as f:
                f.write(image_bytes)
            
            relative_path = f"images/{filename}"
            logger.debug(f"Saved image for step {step}: {relative_path}")
            return relative_path
            
        except Exception as e:
            logger.error(f"Failed to save image for step {step}: {e}")
            return None

    def _save_audio(self, audio_data, step: int) -> Optional[str]:
        """Save audio data to file and return relative path."""
        if not self.save_media or not audio_data:
            return None
        
        try:
            # Handle both base64 audio and JSON text audio
            if isinstance(audio_data, str):
                try:
                    # Try to decode as base64 first (real audio)
                    audio_bytes = base64.b64decode(audio_data)
                    
                    # Check if it looks like audio data (has some audio-like characteristics)
                    if len(audio_bytes) > 100:  # Audio files are typically larger
                        # Save as audio file
                        filename = f"step_{step:03d}_audio.mp3"
                        filepath = self.audio_dir / filename
                        
                        with open(filepath, 'wb') as f:
                            f.write(audio_bytes)
                        
                        relative_path = f"audio/{filename}"
                        logger.debug(f"Saved audio file for step {step}: {relative_path}")
                        return relative_path
                    else:
                        # Small data, probably not real audio
                        raise ValueError("Data too small to be audio")
                        
                except Exception:
                    # Not base64 audio, try as JSON text
                    try:
                        audio_json = json.loads(audio_data)
                        filename = f"step_{step:03d}_audio_text.json"
                        filepath = self.audio_dir / filename
                        
                        with open(filepath, 'w', encoding='utf-8') as f:
                            json.dump(audio_json, f, indent=2, ensure_ascii=False)
                        
                        relative_path = f"audio/{filename}"
                        logger.debug(f"Saved audio text for step {step}: {relative_path}")
                        return relative_path
                    except:
                        # Raw text
                        filename = f"step_{step:03d}_audio_text.txt"
                        filepath = self.audio_dir / filename
                        
                        with open(filepath, 'w', encoding='utf-8') as f:
                            f.write(audio_data)
                        
                        relative_path = f"audio/{filename}"
                        logger.debug(f"Saved raw audio text for step {step}: {relative_path}")
                        return relative_path
                        
            elif isinstance(audio_data, (list, dict)):
                # JSON data
                filename = f"step_{step:03d}_audio.json"
                filepath = self.audio_dir / filename
                
                with open(filepath, 'w', encoding='utf-8') as f:
                    json.dump(audio_data, f, indent=2, ensure_ascii=False)
                
                relative_path = f"audio/{filename}"
                logger.debug(f"Saved audio JSON for step {step}: {relative_path}")
                return relative_path
            
            return None
            
        except Exception as e:
            logger.error(f"Failed to save audio for step {step}: {e}")
            return None
        

    def _save_video(self, video_data, step: int, video_type: str = "auto") -> Optional[str]:
        """Save video data to file and return relative path."""
        if not self.save_media or video_data is None or (hasattr(video_data, 'size') and video_data.size == 0):
            return None
        
        try:
            # Handle both base64 string and raw bytes
            if isinstance(video_data, str):
                # Decode base64 data
                video_bytes = base64.b64decode(video_data)
            elif isinstance(video_data, bytes):
                # Already bytes
                video_bytes = video_data
            else:
                logger.warning(f"Unexpected video data type: {type(video_data)}")
                return None
            
            # Detect file type by checking magic bytes
            is_jpeg = video_bytes.startswith(b'\xff\xd8\xff')
            is_mp4 = (video_bytes[4:12] == b'ftypmp4' or 
                     video_bytes[4:12] == b'ftypisom' or
                     video_bytes[4:8] == b'ftyp' or
                     video_bytes.startswith(b'\x00\x00\x00') and b'ftyp' in video_bytes[:20])
            
            # Determine filename based on type and detection
            if video_type == "input":
                # This is API input content - use descriptive naming
                if is_mp4:
                    filename = f"step_{step:03d}_api_input.mp4"
                    logger.debug(f"Saving API input video (MP4) for step {step}")
                else:
                    filename = f"step_{step:03d}_api_input.jpg"
                    logger.debug(f"Saving API input frame (JPEG) for step {step}")
            else:
                # Auto-detect naming (original behavior)
                if is_jpeg:
                    filename = f"step_{step:03d}_frame.jpg"
                    logger.debug(f"Detected JPEG image data for step {step}, saving as {filename}")
                elif is_mp4:
                    filename = f"step_{step:03d}_video.mp4"
                    logger.debug(f"Detected MP4 video data for step {step}, saving as {filename}")
                else:
                    filename = f"step_{step:03d}_video.mp4"
                    logger.warning(f"Could not detect file type for step {step}, defaulting to MP4")
            
            filepath = self.videos_dir / filename
            
            with open(filepath, 'wb') as f:
                f.write(video_bytes)
            
            relative_path = f"videos/{filename}"
            logger.debug(f"Saved video/image for step {step}: {relative_path}")
            return relative_path
            
        except Exception as e:
            logger.error(f"Failed to save video for step {step}: {e}")
            return None

    def _save_api_input_video(self, observation, step: int) -> Optional[str]:
        """Create and save the video content that would be sent to API in video mode."""
        if not self.save_media or not isinstance(observation, dict):
            return None
        
        try:
            # Check if we can create a video input from the environment
            if hasattr(self.env, 'frame_buffer') and len(self.env.frame_buffer) > 0:
                # Create video clip from frame buffer - use enhanced video with audio if available
                if self.enhanced_video:
                    video_bytes = self.env._create_enhanced_video_clip_from_buffer()
                    video_type_desc = "enhanced API input video with audio"
                else:
                    video_bytes = self.env._create_video_clip_from_buffer()
                    video_type_desc = "API input video"
                
                if video_bytes:
                    # Save as API input video
                    video_path = self._save_video(video_bytes, step, video_type="input")
                    if video_path:
                        logger.info(f"Created {video_type_desc} from {len(self.env.frame_buffer)} frames for step {step}")
                        return video_path
            
            # Fallback: if no frame buffer, try to get current frame and save as image input
            if observation.get('image') is not None:
                image_data = observation['image']
                if isinstance(image_data, str):
                    # Already base64 encoded
                    image_base64 = image_data
                    # Convert numpy array to base64
                    from PIL import Image
                    import io
                    
                    if len(image_data.shape) == 3:
                        image_pil = Image.fromarray(image_data.astype(np.uint8))
                        buffer = io.BytesIO()
                        image_pil.save(buffer, format='JPEG', quality=85)
                        image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                    else:
                        return None
                else:
                    return None
                
                if image_base64:
                    # Save as API input frame
                    video_path = self._save_video(image_base64, step, video_type="input")
                    if video_path:
                        logger.info(f"Saved API input frame for step {step}")
                        return video_path
            
            return None
            
        except Exception as e:
            logger.error(f"Failed to create API input video for step {step}: {e}")
            return None

    def _save_model_response(self, response: str, step: int) -> Optional[str]:
        """Save model response to text file and return relative path."""
        if not self.save_media or not response:
            return None
        
        try:
            filename = f"step_{step:03d}_response.txt"
            filepath = self.responses_dir / filename
            
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(response)
            
            relative_path = f"responses/{filename}"
            logger.debug(f"Saved response for step {step}: {relative_path}")
            return relative_path
            
        except Exception as e:
            logger.error(f"Failed to save response for step {step}: {e}")
            return None

    def _get_state_description(self, observation) -> str:
        """Create human-readable state description from observation."""
        try:
            # Handle both dict and array observation formats
            if isinstance(observation, dict):
                # Multi-modal observation format
                vector_obs = observation['vector']
                audio_data = json.loads(observation.get('audio', '[]'))
                has_image = 'image' in observation
            else:
                # Simple vector observation format
                vector_obs = observation
                audio_data = []
                has_image = False
            
            # Ensure vector_obs is a numpy array
            if hasattr(vector_obs, 'shape'):
                vector_obs = vector_obs.flatten()  # Flatten if multi-dimensional
            
            # Parse vector observation
            num_members = (len(vector_obs) - 2) // 4
            description = [f"Team size: {num_members}"]
            
            # Member states
            for i in range(num_members):
                base_idx = i * 4
                health = float(vector_obs[base_idx])
                status_code = int(float(vector_obs[base_idx + 1]))
                x, y = float(vector_obs[base_idx + 2]), float(vector_obs[base_idx + 3])
                
                status_names = ["idle", "moving", "attacking", "defending", "recon", "dead", "injured"]
                status = status_names[status_code] if status_code < len(status_names) else "unknown"
                
                description.append(f"Member {i}: {health:.0f}% health, {status}, at ({x:.0f},{y:.0f})")
            
            # Global state
            rounds_remaining = int(float(vector_obs[-2]))
            score_normalized = float(vector_obs[-1])
            description.append(f"Rounds remaining: {rounds_remaining}")
            description.append(f"Score: {score_normalized:.1f}/100")
            
            # Audio events
            if audio_data:
                description.append(f"Audio: {', '.join(str(msg) for msg in audio_data)}")
            
            # Visual info
            if has_image:
                description.append("Visual: Game state image available")
            
            # Check for video in observation
            if isinstance(observation, dict) and observation.get('video') is not None:
                description.append("Video: Game state video sequence available")
            
            return "\n".join(description)
            
        except Exception as e:
            logger.error(f"Error creating state description: {e}")
            logger.error(f"Observation type: {type(observation)}")
            if hasattr(observation, 'shape'):
                logger.error(f"Observation shape: {observation.shape}")
            elif isinstance(observation, dict):
                logger.error(f"Observation keys: {list(observation.keys())}")
            return "State parsing failed"

    def _extract_command(self, response: str) -> Optional[np.ndarray]:
        """Extract command from model response, supporting both individual and team commands."""
        try:
            # Member name to index mapping (common military call signs)
            member_name_map = {
                'alpha': 0, 'bravo': 1, 'charlie': 2, 'delta': 3, 'echo': 4, 'foxtrot': 5,
                'scout': 0, 'heavy': 1, 'medic': 2, 'engineer': 3, 'sniper': 4, 'support': 5
            }
            
            # Remove markdown code blocks if present
            clean_response = re.sub(r'```[^`]*```', '', response)

            # Check for multiple VALID COMMAND instances first and warn
            # Use more precise regex to avoid markdown formatting like "Command:**"
            # Only count commands that have actual valid content (not just punctuation)
            # Only count commands starting line with COMMAND:
            valid_command_lines = []
            all_command_lines = re.findall(r"^COMMAND:\s*[^\n]+", clean_response, re.IGNORECASE)
            
            for cmd_line in all_command_lines:
                valid_command_lines.append(cmd_line)
            
            if len(valid_command_lines) > 1:
                logger.warning(f"⚠️ Multiple command lines detected ({len(valid_command_lines)}). Using first valid command.")
                logger.warning("Valid commands found:")
                for i, cmd_line in enumerate(valid_command_lines, 1):
                    logger.warning(f"  {i}. {cmd_line}")
            elif len(valid_command_lines) == 0:
                # Check if there are any COMMAND instances at all (including malformed ones)
                loose_command_check = re.findall(r"COMMAND:", clean_response, re.IGNORECASE)
                if len(loose_command_check) > 0:
                    logger.warning("Found COMMAND keyword but no valid command format detected")
                    # Log what was found for debugging
                    for cmd_line in all_command_lines:
                        logger.warning(f"Found: {cmd_line}")
            
            # Find all potential command patterns and take the first valid one
            command_patterns = [
                # Team-wide pattern
                (r"COMMAND:\s*all\s+(move|attack|defend|recon|status)\s+([\d.]+)\s+([\d.]+)", "team"),
                # Multi-member pattern  
                (r"COMMAND:\s*((?:\d+,)*\d+)\s+(move|attack|defend|recon|status)\s+([\d.]+)\s+([\d.]+)", "multi"),
                # Individual member pattern (numeric)
                (r"COMMAND:\s*(\d+)\s+(move|attack|defend|recon|status)\s+([\d.]+)\s+([\d.]+)", "individual"),
                # Member name pattern (fallback)
                (r"COMMAND:\s*(\w+)\s+(move|attack|defend|recon|status)\s+([\d.]+)\s+([\d.]+)", "name")
            ]
            
            for pattern, pattern_type in command_patterns:
                match = re.search(pattern, clean_response, re.IGNORECASE)
                if match:
                    if pattern_type == "team":
                        cmd_type = match.group(1).lower()
                        x, y = float(match.group(2)), float(match.group(3))
                        
                        # Validate and clamp values
                        cmd_idx = self.command_types.index(cmd_type) if cmd_type in self.command_types else 0
                        x = max(0, min(int(round(x)), 100))
                        y = max(0, min(int(round(y)), 100))
                        
                        logger.info(f"✅ Extracted team command: {cmd_type} to ({x},{y})")
                        return np.array([self.env.num_members, cmd_idx, x, y], dtype=np.int32)
                    
                    elif pattern_type == "multi":
                        member_list = match.group(1)
                        cmd_type = match.group(2).lower()
                        x, y = float(match.group(3)), float(match.group(4))
                        
                        # Parse member list
                        try:
                            member_indices = [int(m.strip()) for m in member_list.split(',')]
                            # Validate indices
                            member_indices = [max(0, min(idx, self.env.num_members - 1)) for idx in member_indices]
                            
                            if len(member_indices) > 1:
                                # Validate and clamp values
                                cmd_idx = self.command_types.index(cmd_type) if cmd_type in self.command_types else 0
                                x = max(0, min(int(round(x)), 100))
                                y = max(0, min(int(round(y)), 100))
                                
                                # Store the member list for the gym wrapper to use
                                self._last_member_list = member_indices
                                
                                logger.info(f"✅ Extracted multi-member command: {cmd_type} to ({x},{y}) for members {member_indices}")
                                return np.array([self.env.num_members + 1, cmd_idx, x, y], dtype=np.int32)
                        except ValueError:
                            continue  # Try next pattern
                    
                    elif pattern_type == "individual":
                        member_idx = int(match.group(1))
                        cmd_type = match.group(2).lower()
                        x, y = float(match.group(3)), float(match.group(4))
                        
                        # Validate and clamp values
                        member_idx = max(0, min(member_idx, self.env.num_members - 1))
                        cmd_idx = self.command_types.index(cmd_type) if cmd_type in self.command_types else 0
                        x = max(0, min(int(round(x)), 100))
                        y = max(0, min(int(round(y)), 100))
                        
                        logger.info(f"✅ Extracted individual command: {cmd_type} to ({x},{y}) for member {member_idx}")
                        return np.array([member_idx, cmd_idx, x, y], dtype=np.int32)
                    
                    elif pattern_type == "name":
                        member_name = match.group(1).lower()
                        cmd_type = match.group(2).lower()
                        x, y = float(match.group(3)), float(match.group(4))
                        
                        # Convert member name to index
                        member_idx = member_name_map.get(member_name, 0)
                        
                        # Validate and clamp values
                        member_idx = max(0, min(member_idx, self.env.num_members - 1))
                        cmd_idx = self.command_types.index(cmd_type) if cmd_type in self.command_types else 0
                        x = max(0, min(int(round(x)), 100))
                        y = max(0, min(int(round(y)), 100))
                        
                        logger.info(f"✅ Extracted name-based command: {cmd_type} to ({x},{y}) for {member_name} (index {member_idx})")
                        return np.array([member_idx, cmd_idx, x, y], dtype=np.int32)
            
            # If no standard patterns match, try fallback parsing
            fallback_match = re.search(r"COMMAND:\s*([^\n]+)", clean_response, re.IGNORECASE)
            if fallback_match:
                command_line = fallback_match.group(1).strip()
                logger.warning(f"⚠️ Using fallback parsing for: {command_line}")
                
                # Try to parse the command line directly
                parts = command_line.split()
                if len(parts) >= 4:
                    try:
                        # Try parsing as: member_id action x y
                        member_part = parts[0]
                        cmd_type = parts[1].lower()
                        x, y = float(parts[2]), float(parts[3])
                        
                        # Handle member specification
                        if member_part.lower() == 'all':
                            member_idx = self.env.num_members
                        elif ',' in member_part:
                            # Multi-member command
                            member_indices = [int(m.strip()) for m in member_part.split(',')]
                            member_indices = [max(0, min(idx, self.env.num_members - 1)) for idx in member_indices]
                            if len(member_indices) > 1:
                                self._last_member_list = member_indices
                                member_idx = self.env.num_members + 1
                            else:
                                member_idx = member_indices[0] if member_indices else 0
                        else:
                            # Single member
                            try:
                                member_idx = int(member_part)
                            except ValueError:
                                # Try member name lookup
                                member_idx = member_name_map.get(member_part.lower(), 0)
                        
                        # Validate and return
                        if member_idx != self.env.num_members and member_idx != self.env.num_members + 1:
                            member_idx = max(0, min(member_idx, self.env.num_members - 1))
                        cmd_idx = self.command_types.index(cmd_type) if cmd_type in self.command_types else 0
                        x = max(0, min(int(round(x)), 100))
                        y = max(0, min(int(round(y)), 100))
                        
                        logger.info(f"✅ Fallback extraction successful: {cmd_type} to ({x},{y})")
                        return np.array([member_idx, cmd_idx, x, y], dtype=np.int32)
                        
                    except (ValueError, IndexError) as e:
                        logger.warning(f"❌ Fallback parsing failed: {e}")
            
            logger.warning("⚠️ No valid command found in response. Using default: move to center.")
            logger.debug("📝 Response content preview: " + clean_response[:200] + "...")
            return np.array([0, 0, 50, 50], dtype=np.int32)  # Default: member 0 move to center
            
        except Exception as e:
            logger.error(f"💥 Command extraction crashed: {e}")
            return np.array([0, 0, 50, 50], dtype=np.int32)

    def _query_model(self, observation, step: int) -> Tuple[str, np.ndarray, Dict]:
        """Query the model for command suggestion and save media files."""
        media_paths = {"image": None, "audio": None, "video": None, "response": None, "api_input": None}
        
        try:
            # Prepare content
            if self.include_vector_text:
                state_desc = self._get_state_description(observation)
                state_section = f"""Current game state:
{state_desc}

"""
                vector_info_line = "- Vector: Team member states (health, status, position) + global info (rounds remaining, normalized score)\n"
            else:
                state_section = ""
                vector_info_line = ""
            
            # Customize prompt based on input mode
            if self.input_mode == "video":
                media_info = "Video: Visual sequence showing game state progression (visual only)\nAudio: Tactical guidance and team communications (separate from video)"
            else:
                media_info = "Image: Visual representation of current game state\nAudio: Tactical guidance and team communications"
            
            text_content = f"""{state_section}🚨🚨🚨 CRITICAL REMINDER: EXACTLY ONE COMMAND ONLY! 🚨🚨🚨

You MUST provide exactly ONE command in your response. Multiple commands will cause SYSTEM ERRORS!

❌ DO NOT DO THIS: Provide multiple "COMMAND:" lines
✅ DO THIS: Provide exactly one "COMMAND:" line

Choose the SINGLE most important action for this turn. You can plan additional moves for future turns.

Available inputs:
{vector_info_line}- {media_info}
- Discovery hints: Clues about nearby hidden objectives

Analyze the situation and provide your ONE command."""
            
            content_parts = [{"type": "text", "text": text_content}]
            
            # Handle media input based on input mode
            if self.input_mode == "video":
                # Video input mode (visual only)
                if (self.api_config["vision_support"] and 
                    isinstance(observation, dict) and 
                    observation.get('video') is not None):
                    
                    video_data = observation['video']
                    if isinstance(video_data, str):
                        # Already base64 encoded
                        video_base64 = video_data
                        media_paths["video"] = self._save_video(video_base64, step)
                        
                        # Detect if this is actual video or fallback image
                        video_bytes_test = base64.b64decode(video_base64)
                        is_actual_video = (video_bytes_test[4:12] == b'ftypmp4' or 
                                         video_bytes_test[4:12] == b'ftypisom' or
                                         video_bytes_test[4:8] == b'ftyp')
                        
                        if is_actual_video:
                            content_parts.append({
                                "type": "video_url",
                                "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}
                            })
                            logger.info(f"Added visual-only MP4 video input for {self.api_provider} model")
                        else:
                            content_parts.append({
                                "type": "image_url", 
                                "image_url": {"url": f"data:image/jpeg;base64,{video_base64}"}
                            })
                            logger.info(f"Added visual frame input for {self.api_provider} model")
                
                # Add separate audio input if available and model supports it
                if isinstance(observation, dict) and observation.get('audio'):
                    audio_data = observation['audio']
                    media_paths["audio"] = self._save_audio(audio_data, step)
                    
                    # For now, add audio info as text since most APIs don't support direct audio
                    try:
                        # Try to parse as JSON first
                        if isinstance(audio_data, str) and audio_data.startswith('{'):
                            audio_json = json.loads(audio_data)
                            if audio_json.get("guidance"):
                                audio_text = f"Tactical guidance: {audio_json['guidance']}"
                                if audio_json.get("team_communications"):
                                    comms = ", ".join(audio_json["team_communications"])
                                    audio_text += f"\nTeam communications: {comms}"
                                
                                content_parts.append({
                                    "type": "text",
                                    "text": f"\n[AUDIO INPUT]\n{audio_text}\n[END AUDIO]"
                                })
                                logger.info(f"Added tactical audio guidance as text for {self.api_provider} model")
                    except:
                        # Handle as raw audio/text
                        if isinstance(audio_data, str) and len(audio_data) < 1000:  # Probably text
                            content_parts.append({
                                "type": "text", 
                                "text": f"\n[AUDIO INPUT]\n{audio_data}\n[END AUDIO]"
                            })
                            logger.info(f"Added audio content as text for {self.api_provider} model")
            
            else:
                # Image + Audio input mode (default)
                # Handle image data
                if (self.api_config["vision_support"] and 
                    isinstance(observation, dict) and 
                    observation.get('image') is not None):

                    # Check if image is base64 string or numpy array
                    image_data = observation['image']
                    if isinstance(image_data, str):
                        # Already base64 encoded
                        image_base64 = image_data
                    elif hasattr(image_data, 'shape'):
                        # Numpy array - convert to base64
                        from PIL import Image
                        import io
                        
                        if len(image_data.shape) == 3:
                            # Convert numpy array to PIL Image
                            image_pil = Image.fromarray(image_data.astype(np.uint8))
                            
                            # Convert to base64
                            buffer = io.BytesIO()
                            image_pil.save(buffer, format='JPEG', quality=85)
                            image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                        else:
                            logger.warning(f"Unexpected image shape: {image_data.shape}")
                            image_base64 = None
                    else:
                        logger.warning(f"Unexpected image type: {type(image_data)}")
                        image_base64 = None
                    
                    if image_base64:
                        media_paths["image"] = self._save_image(image_base64, step)
                        content_parts.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
                        })
                        logger.info(f"Added visual input for {self.api_provider} model")
                
                # Save video content if available (from enhanced video recording)
                if isinstance(observation, dict) and observation.get('video') is not None:
                    video_data = observation['video']
                    media_paths["video"] = self._save_video(video_data, step)
                    logger.info(f"Saved visual-only video content (step {step}) - available in image_audio mode")
                
                # Handle separate audio data
                if isinstance(observation, dict) and observation.get('audio'):
                    audio_data = observation['audio']
                    media_paths["audio"] = self._save_audio(audio_data, step)
                    
                    # Add audio content as text (similar to video mode)
                    try:
                        if isinstance(audio_data, str) and audio_data.startswith('{'):
                            audio_json = json.loads(audio_data)
                            if audio_json.get("guidance"):
                                audio_text = f"Tactical guidance: {audio_json['guidance']}"
                                if audio_json.get("team_communications"):
                                    comms = ", ".join(audio_json["team_communications"])
                                    audio_text += f"\nTeam communications: {comms}"
                                
                                content_parts.append({
                                    "type": "text",
                                    "text": f"\n[AUDIO INPUT]\n{audio_text}\n[END AUDIO]"
                                })
                                logger.info(f"Added tactical audio guidance as text")
                    except:
                        if isinstance(audio_data, str) and len(audio_data) < 1000:
                            content_parts.append({
                                "type": "text",
                                "text": f"\n[AUDIO INPUT]\n{audio_data}\n[END AUDIO]"
                            })
                            logger.info(f"Added audio content as text")
            
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": content_parts}
            ]
            
            logger.info(f"Querying {self.api_provider.upper()} model: {self.model_name}")
            
            # Create completion with provider-specific parameters
            completion_params = {
                "model": self.model_name,
                "messages": messages,
                "stream": self.enable_stream
            }
            
            # Add provider-specific parameters
            if self.api_provider == "openai":
                completion_params["max_tokens"] = 1000  # OpenAI models benefit from explicit max_tokens
            
            if self.enable_stream:
                response = self.client.chat.completions.create(**completion_params)
                
                full_response = ""
                for chunk in response:
                    if chunk.choices and chunk.choices[0].delta.content:
                        content = chunk.choices[0].delta.content
                        full_response += content
                        print(content, end='', flush=True)
                print()  # New line after streaming
            else:
                response = self.client.chat.completions.create(**completion_params)
                full_response = response.choices[0].message.content
                logger.info(f"Model response: {full_response}")
            
            # Save model response
            media_paths["response"] = self._save_model_response(full_response, step)
            
            # Save API input video content (what would be sent in video mode)
            media_paths["api_input"] = self._save_api_input_video(observation, step)
            
            # Extract command
            command = self._extract_command(full_response)
            return full_response, command, media_paths
            
        except Exception as e:
            logger.error(f"Model query failed for {self.api_provider}: {e}")
            import traceback
            logger.error(f"Full traceback: {traceback.format_exc()}")
            return f"Error: {e}", np.array([0, 0, 50, 50], dtype=np.int32), media_paths

    def _create_env_for_episode(self, episode_idx: int) -> CoopCommandGymEnv:
        """Create a fresh environment instance for the given episode.

        The episode seed is ``self.seed_index + episode_idx * 1000``, so each
        episode is seeded independently. Recording settings are derived from
        ``input_mode``, ``enhanced_video`` and ``save_media``.

        Args:
            episode_idx: Zero-based index of the episode.

        Returns:
            The newly constructed ``CoopCommandGymEnv``.

        Raises:
            Exception: Any error raised while building the environment is
                logged and re-raised unchanged.
        """
        try:
            # Independent seed per episode: base seed plus a per-episode offset
            episode_seed = self.seed_index + episode_idx * 1000

            # Recordings go to this episode's "videos" directory when one exists
            has_episode_dir = self.save_media and episode_idx in self.episode_dirs
            rec_dir = str(self.episode_dirs[episode_idx]["videos"]) if has_episode_dir else "recordings"

            # Video input always records video; otherwise enhanced video records
            # "both", and plain runs record video only when media is being saved.
            if self.input_mode == "video":
                mode = "video"
            elif self.enhanced_video:
                mode = "both"
            elif self.save_media:
                mode = "video"
            else:
                mode = "individual"

            # The configured frame rate only applies to enhanced video
            fps = 1
            if self.enhanced_video:
                fps = self.video_fps

            env_kwargs = {
                "difficulty": self.difficulty,
                "seed_index": episode_seed,
                "max_rounds": self.max_rounds,
                "enable_audio": True,
                "enable_visual": True,
                "deterministic_commands": self.deterministic_commands,
                "recording_mode": mode,
                "video_fps": fps,
                "enhanced_video": self.enhanced_video,
                "audio_duration_per_frame": self.audio_duration_per_frame,
                "recordings_dir": rec_dir,
            }
            new_env = CoopCommandGymEnv(**env_kwargs)

            logger.info(f"Created environment for episode {episode_idx} with seed {episode_seed}")
            return new_env

        except Exception as e:
            logger.error(f"Failed to create environment for episode {episode_idx}: {e}")
            raise

    def run_evaluation(self) -> Dict:
        """Run the multi-episode evaluation.

        For each of ``self.num_episodes`` episodes, a fresh environment is
        built with ``_create_env_for_episode`` and played with
        ``_run_single_episode``. Per-episode command-compliance counters are
        summed into an overall tally.

        An episode that raises is recorded as a ``"failed"`` result with zeroed
        stats instead of aborting the run. The environment is closed and reset
        to ``None`` after every episode.

        Returns:
            ``self.results`` with ``"episodes"``, ``"summary_stats"`` and
            ``"command_compliance"`` filled in.
        """
        import traceback  # used by the per-episode failure handler below

        logger.info(f"Starting multi-episode evaluation - Provider: {self.api_provider.upper()}, Model: {self.model_name}")
        logger.info(f"Episodes: {self.num_episodes}, Difficulty: {self.difficulty}, Base Seed: {self.seed_index}")
        if self.save_media:
            logger.info(f"Media files will be saved to: {self.output_dir}")
        
        # Counters summed across episodes (compliance_rate is derived at the end)
        summed_keys = ("total_turns", "valid_single_commands", "multiple_command_violations", "no_command_found")
        episode_results = []
        overall_command_compliance = {
            "total_turns": 0,
            "valid_single_commands": 0,
            "multiple_command_violations": 0,
            "no_command_found": 0,
            "compliance_rate": 0.0
        }
        
        # Run each episode in its own freshly created environment
        for episode_idx in range(self.num_episodes):
            logger.info(f"\n{'='*60}")
            logger.info(f"🎮 STARTING EPISODE {episode_idx + 1}/{self.num_episodes}")
            logger.info(f"{'='*60}")
            
            try:
                # Clean up any environment left over from a previous episode
                if self.env:
                    self.env.close()
                self.env = self._create_env_for_episode(episode_idx)
                episode_result = self._run_single_episode(episode_idx)
                
            except Exception as e:
                logger.error(f"💥 Episode {episode_idx + 1} failed: {e}")
                logger.error(f"Full traceback: {traceback.format_exc()}")
                
                # Record a zeroed "failed" result so the episode is still reported
                episode_results.append({
                    "episode_index": episode_idx,
                    "episode_seed": self.seed_index + episode_idx * 1000,
                    "status": "failed",
                    "error": str(e),
                    "steps": [],
                    "final_stats": {
                        "total_steps": 0,
                        "total_reward": 0.0,
                        "final_score_normalized": 0.0,
                        "objectives_completed": 0,
                        "total_objectives": 0,
                        "success_rate": 0.0,
                        "terminated": False,
                        "truncated": True
                    },
                    "command_compliance": {
                        "total_turns": 0,
                        "valid_single_commands": 0,
                        "multiple_command_violations": 0,
                        "no_command_found": 0,
                        "compliance_rate": 0.0
                    }
                })
            
            else:
                # Success path kept outside the try so it cannot add a second "failed" entry
                episode_results.append(episode_result)
                
                ep_compliance = episode_result.get("command_compliance", {})
                for key in summed_keys:
                    overall_command_compliance[key] += ep_compliance.get(key, 0)
                
                final_stats = episode_result['final_stats']
                logger.info(f"✅ Episode {episode_idx + 1} completed:")
                logger.info(f"   Score: {final_stats['final_score_normalized']:.1f}/100")
                logger.info(f"   Steps: {final_stats['total_steps']}")
                logger.info(f"   Objectives: {final_stats['objectives_completed']}/{final_stats['total_objectives']}")
                logger.info(f"   Success Rate: {final_stats['success_rate']:.1%}")
            
            finally:
                # Always release the current episode's environment
                if self.env:
                    try:
                        self.env.close()
                    except Exception as close_err:
                        # Best-effort cleanup: log instead of silently swallowing
                        logger.warning(f"Failed to close environment after episode {episode_idx + 1}: {close_err}")
                    self.env = None
        
        # Overall compliance rate
        if overall_command_compliance["total_turns"] > 0:
            overall_command_compliance["compliance_rate"] = (
                overall_command_compliance["valid_single_commands"] /
                overall_command_compliance["total_turns"]
            )
        
        # Aggregate statistics across episodes
        summary_stats = self._calculate_summary_stats(episode_results)
        
        self.results["episodes"] = episode_results
        self.results["summary_stats"] = summary_stats
        self.results["command_compliance"] = overall_command_compliance
        
        # Final report
        self._log_final_report(summary_stats, overall_command_compliance)
        
        return self.results

    def _run_single_episode(self, episode_idx: int) -> Dict:
        """运行单个episode的评估"""
        # 重置当前episode的追踪变量
        current_episode_compliance = {
            "total_turns": 0,
            "valid_single_commands": 0,
            "multiple_command_violations": 0,
            "no_command_found": 0,
            "compliance_rate": 0.0
        }
        
        # 重置环境
        try:
            observation, info = self.env.reset()
            self.results["config"]["actual_max_rounds"] = info.get("max_rounds", self.max_rounds)
            logger.debug(f"Episode {episode_idx} environment reset successful")
        except Exception as e:
            logger.error(f"Episode {episode_idx} environment reset failed: {e}")
            raise
        
        episode_seed = self.seed_index + episode_idx * 1000
        logger.info(f"Episode {episode_idx + 1} - Score: {info.get('score_normalized', 0):.1f}/100, "
                   f"Max rounds: {info.get('max_rounds', 0)}, Seed: {episode_seed}")
        
        step_count = 0
        total_reward = 0
        steps = []
        
        while True:
            step_count += 1
            logger.debug(f"Episode {episode_idx + 1}, Step {step_count}")
            
            # 获取当前episode的媒体保存目录
            current_images_dir = self.episode_dirs[episode_idx]["images"] if episode_idx in self.episode_dirs else self.images_dir
            current_audio_dir = self.episode_dirs[episode_idx]["audio"] if episode_idx in self.episode_dirs else self.audio_dir
            current_videos_dir = self.episode_dirs[episode_idx]["videos"] if episode_idx in self.episode_dirs else self.videos_dir
            current_responses_dir = self.episode_dirs[episode_idx]["responses"] if episode_idx in self.episode_dirs else self.responses_dir
            
            # 临时修改保存目录
            orig_dirs = None
            if self.save_media:
                orig_dirs = (self.images_dir, self.audio_dir, self.videos_dir, self.responses_dir)
                self.images_dir = current_images_dir
                self.audio_dir = current_audio_dir  
                self.videos_dir = current_videos_dir
                self.responses_dir = current_responses_dir
            
            try:
                # Query model and save media files
                model_response, command, media_paths = self._query_model(observation, step_count)
                
                # 恢复原始目录
                if orig_dirs:
                    self.images_dir, self.audio_dir, self.videos_dir, self.responses_dir = orig_dirs
                
                # Validate single command requirement and track compliance
                current_episode_compliance["total_turns"] += 1
                
                # Use improved command detection that filters out markdown formatting
                valid_command_lines = []
                all_command_lines = re.findall(r"COMMAND:\s*[^\n]+", model_response, re.IGNORECASE)
                
                for cmd_line in all_command_lines:
                    # Filter out markdown formatting that isn't a real command
                    if not re.match(r"COMMAND:\s*[\*\#\-\`]+", cmd_line, re.IGNORECASE):
                        # Check if it contains actual command content
                        # Updated pattern to handle multi-member commands with commas (e.g., "0,1 move")
                        if re.search(r"COMMAND:\s*(?:\w+(?:,\w+)*|all)\s+\w+", cmd_line, re.IGNORECASE):
                            valid_command_lines.append(cmd_line)
                
                command_count = len(valid_command_lines)
                
                if command_count > 1:
                    current_episode_compliance["multiple_command_violations"] += 1
                    logger.warning(f"⚠️ Multiple valid commands detected ({command_count}). Using first valid command.")
                    logger.info("Valid commands found:")
                    for i, cmd in enumerate(valid_command_lines, 1):
                        logger.info(f"  {i}. {cmd}")
                        
                elif command_count == 0:
                    current_episode_compliance["no_command_found"] += 1
                    logger.warning("⚠️ No valid COMMAND found in response. Using default command.")
                else:
                    current_episode_compliance["valid_single_commands"] += 1
                    logger.debug(f"✅ Valid single command detected")
                
                # Set member list for multi-member commands before execution
                if hasattr(self, '_last_member_list') and self._last_member_list and command[0] == self.env.num_members + 1:
                    self.env.set_multi_member_list(self._last_member_list)
                
                # Execute command
                obs, reward, terminated, truncated, info = self.env.step(command)
                total_reward += reward
                
                # Generate proper command description based on command type
                try:
                    member_idx = int(command[0]) if len(command) > 0 else 0
                    cmd_idx = int(command[1]) if len(command) > 1 else 0
                    x = int(command[2]) if len(command) > 2 else 50
                    y = int(command[3]) if len(command) > 3 else 50
                    
                    # Validate command index
                    if 0 <= cmd_idx < len(self.command_types):
                        cmd_type = self.command_types[cmd_idx]
                    else:
                        cmd_type = self.command_types[0]  # Default to first command type
                        
                except (IndexError, ValueError, TypeError) as e:
                    logger.error(f"Error parsing command array: {e}, command: {command}")
                    member_idx, cmd_idx, x, y = 0, 0, 50, 50
                    cmd_type = self.command_types[0]
                
                if member_idx == self.env.num_members:
                    # Team-wide command
                    command_desc = f"{cmd_type} to ({x},{y}) by all team members"
                elif member_idx == self.env.num_members + 1:
                    # Multi-member command
                    command_desc = f"{cmd_type} to ({x},{y}) by multiple members"
                else:
                    # Individual member command
                    command_desc = f"{cmd_type} to ({x},{y}) by member {member_idx}"
                
                # Log step results with media paths
                step_result = {
                    "step": step_count,
                    "command": command.tolist(),
                    "command_desc": command_desc,
                    "reward": float(reward),
                    "total_reward": float(total_reward),
                    "score_normalized": float(info.get('score_normalized', 0)),
                    "rounds_remaining": info.get('rounds_remaining', 0),
                    "objectives_completed": info.get('objectives_completed', 0),
                    "model_response_length": len(model_response),
                    "terminated": terminated,
                    "truncated": truncated,
                    "media_paths": media_paths  # Add media file paths
                }
                
                steps.append(step_result)
                
                logger.debug(f"Episode {episode_idx + 1}, Step {step_count}: {command_desc}")
                logger.debug(f"Reward: {reward:.2f}, Total: {total_reward:.2f}, Score: {info.get('score_normalized', 0):.1f}/100")
                
                # Update observation
                observation = obs
                
                # Check termination
                if terminated or truncated:
                    logger.info(f"Episode {episode_idx + 1} ended - Terminated: {terminated}, Truncated: {truncated}")
                    break
                    
            except Exception as e:
                logger.error(f"Error in episode {episode_idx + 1}, step {step_count}: {e}")
                # 恢复目录（如果出错）
                if orig_dirs:
                    self.images_dir, self.audio_dir, self.videos_dir, self.responses_dir = orig_dirs
                raise
        
        # Calculate episode compliance rate
        if current_episode_compliance["total_turns"] > 0:
            current_episode_compliance["compliance_rate"] = (
                current_episode_compliance["valid_single_commands"] / 
                current_episode_compliance["total_turns"]
            )
        
        # Final statistics for this episode
        final_stats = {
            "total_steps": step_count,
            "total_reward": float(total_reward),
            "final_score_normalized": float(info.get('score_normalized', 0)),
            "objectives_completed": info.get('objectives_completed', 0),
            "total_objectives": info.get('total_objectives', 0),
            "success_rate": info.get('objectives_completed', 0) / max(1, info.get('total_objectives', 1)),
            "terminated": terminated,
            "truncated": truncated
        }
        
        return {
            "episode_index": episode_idx,
            "episode_seed": episode_seed,
            "status": "completed",
            "steps": steps,
            "final_stats": final_stats,
            "command_compliance": current_episode_compliance
        }

    def _calculate_summary_stats(self, episode_results: List[Dict]) -> Dict:
        """计算所有episode的汇总统计"""
        if not episode_results:
            return {}
        
        # 收集所有已完成episode的统计数据
        completed_episodes = [ep for ep in episode_results if ep.get("status") == "completed"]
        failed_episodes = [ep for ep in episode_results if ep.get("status") == "failed"]
        
        if not completed_episodes:
            return {
                "total_episodes": len(episode_results),
                "completed_episodes": 0,
                "failed_episodes": len(failed_episodes),
                "success_rate": 0.0
            }
        
        # 提取各项指标
        scores = [ep["final_stats"]["final_score_normalized"] for ep in completed_episodes]
        steps = [ep["final_stats"]["total_steps"] for ep in completed_episodes]
        objectives_completed = [ep["final_stats"]["objectives_completed"] for ep in completed_episodes]
        total_objectives = [ep["final_stats"]["total_objectives"] for ep in completed_episodes]
        episode_success_rates = [ep["final_stats"]["success_rate"] for ep in completed_episodes]
        
        return {
            "total_episodes": len(episode_results),
            "completed_episodes": len(completed_episodes),
            "failed_episodes": len(failed_episodes),
            "completion_rate": len(completed_episodes) / len(episode_results),
            
            # 分数统计
            "score_stats": {
                "mean": np.mean(scores),
                "std": np.std(scores),
                "min": np.min(scores),
                "max": np.max(scores),
                "median": np.median(scores)
            },
            
            # 步数统计
            "steps_stats": {
                "mean": np.mean(steps),
                "std": np.std(steps),
                "min": np.min(steps),
                "max": np.max(steps),
                "median": np.median(steps)
            },
            
            # 目标完成统计
            "objectives_stats": {
                "total_completed": sum(objectives_completed),
                "total_available": sum(total_objectives),
                "mean_completed_per_episode": np.mean(objectives_completed),
                "mean_success_rate": np.mean(episode_success_rates),
                "episodes_with_100_percent": sum(1 for rate in episode_success_rates if rate >= 1.0),
                "episodes_with_50_percent_plus": sum(1 for rate in episode_success_rates if rate >= 0.5)
            },
            
            # 整体成功指标
            "overall_success_rate": np.mean(episode_success_rates),
            "consistency": 1.0 - (np.std(scores) / 100.0),  # 一致性指标 (分数标准差的倒数)
        }

    def _log_final_report(self, summary_stats: Dict, overall_command_compliance: Dict):
        """Log the final human-readable report for a multi-episode evaluation.

        Args:
            summary_stats: Aggregate statistics as produced by
                ``_calculate_summary_stats``. It may be empty. The
                ``score_stats`` / ``objectives_stats`` / ``steps_stats`` /
                ``consistency`` entries are absent when no episode completed,
                and each section is only reported when its data is present.
            overall_command_compliance: Command-compliance counters with the keys
                ``total_turns``, ``valid_single_commands``,
                ``multiple_command_violations``, ``no_command_found`` and
                ``compliance_rate``. Missing counters are reported as 0. A
                missing rate is derived from the valid/total turn counters.
        """
        separator = '=' * 80
        total_episodes = summary_stats.get('total_episodes', 0)

        logger.info(f"\n{separator}")
        logger.info("🏆 MULTI-EPISODE EVALUATION REPORT")
        logger.info(separator)

        # Basic run information
        logger.info(f"API Provider: {self.api_provider.upper()}, Model: {self.model_name}")
        logger.info(f"Episodes: {total_episodes} total, "
                    f"{summary_stats.get('completed_episodes', 0)} completed, "
                    f"{summary_stats.get('failed_episodes', 0)} failed")
        logger.info(f"Completion Rate: {summary_stats.get('completion_rate', 0):.1%}")

        # Score statistics
        score_stats = summary_stats.get("score_stats")
        if score_stats:
            logger.info("\n📊 SCORE STATISTICS:")
            logger.info(f"  Mean Score: {score_stats['mean']:.1f}/100 (±{score_stats['std']:.1f})")
            logger.info(f"  Score Range: {score_stats['min']:.1f} - {score_stats['max']:.1f}")
            logger.info(f"  Median Score: {score_stats['median']:.1f}/100")

        # Objective-completion statistics
        obj_stats = summary_stats.get("objectives_stats")
        if obj_stats:
            logger.info("\n🎯 OBJECTIVES STATISTICS:")
            logger.info(f"  Overall Success Rate: {obj_stats['mean_success_rate']:.1%}")
            logger.info(f"  Total Objectives Completed: {obj_stats['total_completed']}/{obj_stats['total_available']}")
            logger.info(f"  Episodes with 100% Success: {obj_stats['episodes_with_100_percent']}/{total_episodes}")
            logger.info(f"  Episodes with 50%+ Success: {obj_stats['episodes_with_50_percent_plus']}/{total_episodes}")

        # Efficiency statistics
        steps_stats = summary_stats.get("steps_stats")
        if steps_stats:
            logger.info("\n⚡ EFFICIENCY STATISTICS:")
            logger.info(f"  Mean Steps per Episode: {steps_stats['mean']:.1f} (±{steps_stats['std']:.1f})")
            logger.info(f"  Steps Range: {steps_stats['min']:.0f} - {steps_stats['max']:.0f}")

        # Consistency assessment: only meaningful when at least one episode completed.
        # Without data, a default of 0 would wrongly be reported as "low consistency".
        logger.info("\n🎲 CONSISTENCY ANALYSIS:")
        if "consistency" in summary_stats:
            consistency = summary_stats["consistency"]
            logger.info(f"  Performance Consistency: {consistency:.1%}")
            if consistency >= 0.8:
                logger.info("  ✅ High consistency - Stable performance across episodes")
            elif consistency >= 0.6:
                logger.info("  ⚠️ Moderate consistency - Some performance variation")
            else:
                logger.info("  ❌ Low consistency - High performance variation")
        else:
            logger.info("  No completed episodes - consistency not available")

        # Command compliance report
        compliance = overall_command_compliance or {}
        total_turns = compliance.get('total_turns', 0)
        valid_single = compliance.get('valid_single_commands', 0)
        compliance_rate = compliance.get('compliance_rate')
        if compliance_rate is None:
            compliance_rate = valid_single / total_turns if total_turns else 0.0

        logger.info("\n🔧 COMMAND COMPLIANCE REPORT:")
        logger.info(f"  Total Turns: {total_turns}")
        logger.info(f"  Valid Single Commands: {valid_single}")
        logger.info(f"  Multiple Command Violations: {compliance.get('multiple_command_violations', 0)}")
        logger.info(f"  No Command Found: {compliance.get('no_command_found', 0)}")
        logger.info(f"  Overall Compliance Rate: {compliance_rate:.1%}")

        if total_turns == 0:
            # No turns means nothing to judge - not a 100% violation rate.
            logger.warning("  ⚠️ No turns were recorded - command compliance could not be assessed")
        elif compliance_rate < 1.0:
            violation_rate = 100 - compliance_rate * 100
            logger.warning(f"  ⚠️ Model violated single command constraint in {violation_rate:.1f}% of turns!")
        else:
            logger.info("  ✅ Perfect command compliance achieved across all episodes!")

        logger.info(separator)

    def save_results(self, filename: Optional[str] = None) -> str:
        """Save the multi-episode results to a JSON file.

        Args:
            filename: Name of the output file.
                - Media saving enabled: the file is placed in ``self.output_dir``;
                  with no filename, ``results.json`` is used.
                - Media saving disabled: the path is taken as given; with no
                  filename, a descriptive name built from provider, difficulty,
                  seed, episode count and a timestamp is used.

        Returns:
            The path of the written file, as a string.
        """
        if self.save_media:
            # Keep the results next to the media of the same run.
            filepath = self.output_dir / (filename if filename is not None else "results.json")
        else:
            if filename is None:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"{self.api_provider}_eval_{self.difficulty}_seed{self.seed_index}_{self.num_episodes}ep_{timestamp}.json"
            filepath = Path(filename)

        episodes = self.results["episodes"]
        total_steps = sum(len(ep.get("steps", [])) for ep in episodes)

        # Add summary of media files to results
        if self.save_media:
            total_videos = len([ep for ep in episodes if ep.get("status") == "completed"])

            self.results["media_summary"] = {
                "total_episodes": len(episodes),
                "total_images_across_episodes": total_steps,
                "total_videos_across_episodes": total_videos,
                "output_directory": str(self.output_dir),
                "episode_directories": {
                    f"episode_{i:02d}": str(self.episode_dirs[i]["root"])
                    for i in range(len(episodes))
                    if i in self.episode_dirs
                }
            }

        def _json_default(obj):
            # NumPy scalars/arrays can leak in from env info dicts and np-based stats,
            # and Path objects from media bookkeeping; the stdlib encoder rejects them.
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, Path):
                return str(obj)
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        # Serialize first so an encoding error cannot leave a truncated file behind.
        payload = json.dumps(self.results, indent=2, ensure_ascii=False, default=_json_default)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"Multi-episode results saved to: {filepath.absolute()}")
        if self.save_media:
            logger.info(f"Media files directory structure:")
            logger.info(f"  Root: {self.output_dir.absolute()}")
            logger.info(f"  Episodes: {len(episodes)} episode subdirectories")
            logger.info(f"  Total Steps: {total_steps}")

        return str(filepath)

    def close(self):
        """Clean up resources by closing the game environment.

        Safe to call when the environment was never created (e.g. when
        initialization failed part-way) or when ``env`` is ``None``.
        """
        env = getattr(self, "env", None)
        if env is not None:
            env.close()

def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for the multi-episode evaluation.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as the CLI entry point does.

    Returns:
        ``argparse.Namespace`` with the parsed options.

    Exits with status 2 (via argparse) on invalid options, including a
    non-positive ``--num_episodes``, ``--video_fps`` or
    ``--audio_duration_per_frame``.
    """
    def _positive_int(text: str) -> int:
        # A ValueError from int() is reported by argparse as an invalid value.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    def _positive_float(text: str) -> float:
        value = float(text)
        if not value > 0:  # also rejects NaN
            raise argparse.ArgumentTypeError(f"must be a positive number, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Multi-Episode Cooperative Command Game Evaluation")
    parser.add_argument("--difficulty", type=str, default="normal", choices=["normal", "medium", "hard"])
    parser.add_argument("--seed_index", type=int, default=0)
    parser.add_argument("--save_media", action="store_true")
    parser.add_argument("--probabilistic_commands", action="store_true")
    parser.add_argument("--api_provider", type=str, default="qwen", choices=["qwen", "openai"],
                        help="API provider to use")
    parser.add_argument("--model_name", type=str, default=None,
                        help="Model name (defaults to provider's default model)")
    parser.add_argument("--no_stream", action="store_true",
                        help="Disable streaming responses")
    parser.add_argument("--input_mode", type=str, default="video",
                        choices=["image_audio", "video"],
                        help="Input modality mode: 'image_audio' for separate image and audio inputs, 'video' for video input")
    parser.add_argument("--no_vector_text", action="store_true",
                        help="Exclude vector information from text prompt (rely on visual interpretation only)")
    parser.add_argument("--enhanced_video", action="store_true",
                        help="Enable enhanced video recording with integrated audio")
    parser.add_argument("--video_fps", type=_positive_float, default=0.5,
                        help="Frames per second for video recording (default: 0.5 for audio integration)")
    parser.add_argument("--audio_duration_per_frame", type=_positive_float, default=3.0,
                        help="Expected audio duration per frame in seconds (default: 3.0)")
    # Number of episodes in a multi-episode run
    parser.add_argument("--num_episodes", type=_positive_int, default=10,
                        help="Number of episodes to evaluate (default: 10)")
    return parser.parse_args(argv)

def main() -> int:
    """Run the multi-episode evaluation from the command line.

    Parses CLI options, prints the run configuration, builds and runs the
    evaluator, saves the results and prints a short summary.

    Returns:
        Process exit status: 0 on success, 130 when interrupted by the user,
        1 when building or running the evaluation raised an exception.
    """
    args = parse_args()
    difficulty = args.difficulty
    seed_index = args.seed_index
    save_media = args.save_media
    probabilistic_commands = args.probabilistic_commands
    api_provider = args.api_provider
    model_name = args.model_name
    enable_stream = not args.no_stream
    input_mode = args.input_mode
    include_vector_text = not args.no_vector_text
    enhanced_video = args.enhanced_video
    video_fps = args.video_fps
    audio_duration_per_frame = args.audio_duration_per_frame
    num_episodes = args.num_episodes

    # Display configuration
    print(f"\n🚀 Multi-Episode Evaluation Configuration")
    print(f"📍 API Provider: {api_provider.upper()}")
    if model_name:
        print(f"🤖 Model: {model_name}")
    else:
        print(f"🤖 Model: {API_CONFIGS[api_provider]['default_model']} (default)")
    print(f"🎮 Difficulty: {difficulty.upper()}, Base Seed: {seed_index}")
    print(f"🔄 Episodes: {num_episodes}")
    print(f"💾 Save Media: {'Yes' if save_media else 'No'}")
    print(f"🎯 Command Execution: {'Probabilistic' if probabilistic_commands else 'Deterministic'}")
    print(f"🔄 Streaming: {'Enabled' if enable_stream else 'Disabled'}")
    print(f"🎬 Input Mode: {input_mode.upper()}")
    print(f"📊 Include Vector Text: {'Yes' if include_vector_text else 'No'}")
    print(f"🎥 Enhanced Video: {'Yes' if enhanced_video else 'No'}")
    print(f"🎥 Video FPS: {video_fps}")
    print(f"🎤 Audio Duration per Frame: {audio_duration_per_frame}s")

    # Check vision support
    if API_CONFIGS[api_provider]["vision_support"]:
        if input_mode == "video":
            print(f"👁️  Vision: Supported (video input)")
        else:
            print(f"👁️  Vision: Supported (image + audio input)")
    else:
        print(f"👁️  Vision: Not supported (text-only mode)")

    print("=" * 60)

    # Constructed inside the try so setup failures (e.g. missing API key) are logged too.
    evaluator = None
    try:
        evaluator = MultiProviderEvaluator(
            difficulty=difficulty,
            seed_index=seed_index,
            enable_stream=enable_stream,
            save_media=save_media,
            deterministic_commands=not probabilistic_commands,
            api_provider=api_provider,
            model_name=model_name,
            input_mode=input_mode,
            include_vector_text=include_vector_text,
            enhanced_video=enhanced_video,
            video_fps=video_fps,
            audio_duration_per_frame=audio_duration_per_frame,
            num_episodes=num_episodes
        )

        results = evaluator.run_evaluation()
        filepath = evaluator.save_results()
        summary = results.get('summary_stats') or {}

        print(f"\n✅ Multi-episode evaluation completed!")
        print(f"📄 Results saved to: {filepath}")
        print(f"📈 Episodes: {summary.get('completed_episodes', 0)}/{summary.get('total_episodes', 0)}")
        print(f"🏆 Mean Score: {summary.get('score_stats', {}).get('mean', 0):.1f}/100")
        print(f"🎯 Overall Success Rate: {summary.get('objectives_stats', {}).get('mean_success_rate', 0):.1%}")
        print(f"🎮 Provider: {api_provider.upper()}, Model: {evaluator.model_name}")
        return 0

    except KeyboardInterrupt:
        print("\n⛔ Multi-episode evaluation interrupted by user")
        return 130  # conventional exit status for SIGINT
    except Exception as e:
        # logger.exception attaches the full traceback to the log record.
        logger.exception(f"💥 Multi-episode evaluation failed: {e}")
        return 1
    finally:
        # Release the environment only if it was actually created.
        if evaluator is not None and getattr(evaluator, 'env', None):
            evaluator.close()

if __name__ == "__main__":
    # Propagate main()'s status so failures are visible to shells and schedulers.
    raise SystemExit(main())